--- title: Metrics for unpaired image-to-image translation keywords: fastai sidebar: home_sidebar summary: "Defines functionality for implementing metrics for unpaired image-to-image translation, including common metrics like FID, KID, etc." description: "Defines functionality for implementing metrics for unpaired image-to-image translation, including common metrics like FID, KID, etc." nb_path: "nbs/05_metrics.ipynb" ---
# Fix the random seed (999) so the tests and training below are repeatable.
# NOTE(review): reproducible=True presumably also requests deterministic
# backend behavior — confirm in the fastai `set_seed` docs.
set_seed(999, reproducible=True)
This code is based on this implementation and this implementation, adapted to fastai's metric API.
The FrechetInceptionDistance metric works by initializing an Inception model, extracting Inception activation features for each batch of predictions and example images (targets), and, at the end, calculating the statistics and the Frechet distance. Below are tests for each of these components.
# Build a CPU FID metric and check the size of the features it extracts.
fid = FrechetInceptionDistance(device='cpu')
size = (224, 224, 3)
# Black, mid-gray and white images with channels on the last axis, the whole
# list repeated twice -> 6 images.
arrays = [np.full(size, level) for level in (0.0, 0.5, 1.0)] * 2
img_like_tensor = torch.from_numpy(np.array(arrays)).float()
# The images are permuted to channel-first before being passed in; every image
# is expected to yield a 2048-dimensional feature vector.
test_eq(
    fid.calc_activations_for_batch(
        img_like_tensor.permute(0, 3, 1, 2), model=fid.model, device='cpu'
    ).shape,
    (img_like_tensor.shape[0], 2048),
)
class fake_model(nn.Module):
    """Cheap stand-in for the Inception network in the statistics test.

    Instead of deep features it returns the mean of each channel over the
    last two dimensions, so a 4-D (N, C, H, W) input gives (N, C) outputs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Average over dims 2 and 3 (the spatial dims for channel-first input).
        return x.mean(dim=(2, 3))
size = (4, 4, 3)
# Three tiny constant images with pixel values 0, 0.5 and 1.
arrays = [np.full(size, level) for level in (0.0, 0.5, 1.0)]
input_tensor = torch.from_numpy(np.array(arrays)).float()
# fake_model averages each channel, so the activations are a row of 0s, a row
# of 0.5s and a row of 1s.
stats = fid.calculate_activation_statistics(
    fid.calc_activations_for_batch(
        input_tensor.permute(0, 3, 1, 2), model=fake_model()
    )
)
# Each feature column is (0, 0.5, 1): its mean is 0.5, and its covariance with
# any column is ((-0.5)**2 + 0**2 + 0.5**2) / 2 = 0.25, i.e. the n-1 normalized
# sample covariance.
test_eq(stats[0], np.full((3,), 0.5))
test_eq(stats[1], np.full((3, 3), 0.25))
# Two 2048-d Gaussians sharing an identity covariance, whose means differ by 1
# in every coordinate.
m1 = np.zeros((2048,))
m2 = np.ones((2048,))
sigma = np.eye(2048)
# With Sigma1 == Sigma2 the trace term Tr(Sigma1 + Sigma2 - 2*(Sigma1 Sigma2)^(1/2))
# vanishes, so the Frechet distance reduces to the squared norm of the mean difference.
test_eq(fid.calculate_frechet_distance(m1, sigma, m2, sigma), np.square(m1 - m2).sum())
class FakeLearner:
    """Minimal stand-in for a fastai Learner exposing only `yb` and `pred`.

    Both the targets and the predictions are the channel-first version of the
    module-level `img_like_tensor`; it is passed to `fid.accumulate` below.
    """

    def __init__(self):
        channels_first = (0, 3, 1, 2)
        self.yb = [img_like_tensor.permute(*channels_first)]
        # NOTE(review): pred[0] is left as None, presumably because the metric
        # only reads pred[1] — confirm against FrechetInceptionDistance.accumulate.
        self.pred = [None, img_like_tensor.permute(*channels_first)]
learn = FakeLearner()
%%time
for i in range(5):
fid.accumulate(learn)
print(fid.value)
# Download and extract the horse2zebra CycleGAN dataset, then train a small
# CycleGAN on it with FID tracked as a metric.
horse2zebra = untar_data('https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip')
folders = horse2zebra.ls().sorted()  # full listing, kept for inspection
# Select the domain folders by name instead of by position in the sorted
# listing: positional indices silently pick the wrong folder if the archive
# ever contains an extra entry (e.g. a README or a hidden file).
trainA_path = horse2zebra/'trainA'
trainB_path = horse2zebra/'trainB'
testA_path = horse2zebra/'testA'
testB_path = horse2zebra/'testB'
# NOTE(review): num_A=100 presumably caps the number of domain-A images used.
dls = get_dls(trainA_path, trainB_path, num_A=100)
cycle_gan = CycleGAN(3, 3, 64)
learn = cycle_learner(dls, cycle_gan, metrics=[FrechetInceptionDistance()], show_img_interval=1)
# NOTE(review): presumably 5 epochs at a flat lr followed by 5 with linear
# decay, at lr=2e-4 — confirm against fit_flat_lin.
learn.fit_flat_lin(5, 5, 2e-4)